import os

import pytest
import test_converter

import nle.env.tasks
from nle.dataset import db
from nle.dataset import populate_db

# (user, sample ttyrec name) pairs. The name is resolved to a file via
# test_converter.getfilename(); `mockdata` gives each user a directory holding
# copies of that user's ttyrec.
TTYRECS = [
    ("aaa", test_converter.TTYREC_2020),
    ("bbb", test_converter.TTYREC_IBMGRAPHICS),
    ("ccc", test_converter.TTYREC_DECGRAPHICS),
]

# Durations expressed in seconds.
HOUR = 60 * 60
YEAR = HOUR * 24 * 365


def gen_ttyrecs(savedir):
    """Play short NetHackChallenge games so nle records ttyrecs into ``savedir``.

    Runs 2 environments one after another, each for 5 episodes. Every episode
    is driven by the same fixed key sequence and is asserted to finish
    (``done``). Games need to end naturally to be recorded.

    Args:
        savedir: passed straight through to ``NetHackChallenge(savedir=...)``.

    Raises:
        AssertionError: if the key sequence does not end an episode.
    """
    for _env_idx in range(2):
        env = nle.env.tasks.NetHackChallenge(
            savedir=savedir, character="mon-hum-neu-mal", save_ttyrec_every=2
        )
        try:
            env.reset()
            for _episode in range(5):
                # Need to end naturally to be recorded.
                # NOTE(review): the keys presumably dismiss the opening messages
                # (" " twice), climb the up staircase ("<") and confirm leaving
                # the dungeon ("y") — confirm.
                for c in [ord(" "), ord(" "), ord("<"), ord("y")]:
                    _, _, done, *_ = env.step(env.actions.index(c))
                assert done
                env.reset()
        finally:
            # Close even when an assertion or step fails, so the env is not leaked.
            env.close()


def _write_mock_ttyrecs(tmpdir_factory):
    """Write the per-user sample ttyrecs used by the "basictest" dataset.

    For each ``(user, ttyrec)`` in ``TTYRECS`` this creates ``<basetemp>/<user>/``
    containing ``aaa.ttyrec.bz2``, ``bbb.ttyrec.bz2`` and ``ccc.ttyrec.bz2``,
    all three holding the bytes of that user's sample ttyrec.

    Returns:
        list[str]: paths of the written files in user-major order (all of
        "aaa"'s files, then "bbb"'s, then "ccc"'s). If a user directory
        already exists, that user and every later one are treated as already
        written and skipped, so their paths are not included.
    """
    basetemp = tmpdir_factory.getbasetemp()
    paths = []
    for user, ttyrec in TTYRECS:
        if basetemp.join(user).exists():
            break
        with open(test_converter.getfilename(ttyrec), "rb") as f:
            rec = f.read()
        userdir = tmpdir_factory.mktemp(user, numbered=False)
        for name, _ in TTYRECS:
            fn = userdir.join(name + ".ttyrec.bz2")
            fn.ensure()
            fn.write(rec, "wb")
            paths.append(str(fn))
    return paths


def _add_basictest_dataset(root, paths):
    """Register the files from ``_write_mock_ttyrecs`` as the "basictest" dataset.

    Called by ``mockdata`` after it has chdir'd into the directory holding
    ``ttyrecs.db`` (``db.db`` is opened without an explicit path).
    File paths are stored relative to ``root``. Users "aaa" and "bbb" get
    three single-part games each (gameids 1-6). "ccc"'s three files become
    parts 0-2 of a single game (gameid 7). Sizes are synthetic: 100, ..., 900.
    """
    dataset_name = "basictest"
    relpaths = [os.path.relpath(p, root) for p in paths]
    gameids = [1, 2, 3, 4, 5, 6, 7, 7, 7]
    parts = [0, 0, 0, 0, 0, 0, 0, 1, 2]
    sizes = range(100, 1000, 100)

    with db.db(rw=True) as dbconn:
        c = dbconn.cursor()
        c.execute(
            (
                "INSERT INTO roots (root, dataset_name, ttyrec_version) "
                "VALUES (?,?,1)"
            ),
            (root, dataset_name),
        )
        # Exactly one datasets row per distinct gameid (1..7).
        c.executemany(
            "INSERT INTO datasets (gameid, dataset_name) VALUES (?, ?)",
            [(gameid, dataset_name) for gameid in range(1, 8)],
        )
        c.executemany(
            "INSERT INTO ttyrecs (path, part, gameid, size) VALUES (?,?, ?,?)",
            zip(relpaths, parts, gameids, sizes),
        )
        c.executemany(
            "INSERT INTO games (gameid, name, death, points) VALUES (?,?,?,?)",
            [
                (1, "aaa", "met a sticky end", 0),
                (2, "aaa", "got into a pickle", 1),
                (3, "aaa", "ran into trouble", 2),
                (4, "bbb", "met a sticky end", 3),
                (5, "bbb", "got into a pickle", 4),
                (6, "bbb", "ran into trouble", 5),
                (7, "ccc", "ascended", 999),
            ],
        )
        dbconn.commit()


@pytest.fixture(scope="session")
def mockdata(tmpdir_factory):  # Create mock data.
    """This fixture needs to be imported to generate a fake db.

    It generates 1 database (``ttyrecs.db`` in pytest's base temp dir) with
    3 datasets in it:

    1. "basictest" - test basic construction of db
        * 3 users "aaa", "bbb", "ccc",
        * each user with 3 copies of the same ttyrec ("[aaa|bbb|ccc].ttyrec.bz2")
          in their own directory
        * user "ccc" has all of their ttyrecs registered as just one "game"/episode
        * NOTE(review): "aaa" and "ccc" were described as having the same
          ttyrecs, but TTYRECS maps them to different sample files
          (TTYREC_2020 vs TTYREC_DECGRAPHICS) — confirm which is intended.
    2. "altorgtest" - test basic construction of db from the
        nle/tests/altorg subdirectory
    3. "nletest" - test basic construction of db from data generated by nle

    The cwd is switched to the base temp dir while the fixture is active and
    restored on teardown. Yields that directory.
    """
    paths = _write_mock_ttyrecs(tmpdir_factory)

    oldcwd = os.getcwd()
    newcwd = tmpdir_factory.getbasetemp()
    root = str(newcwd)

    try:
        os.chdir(newcwd)
        # The db is built once. An existing ttyrecs.db is reused as-is.
        if not newcwd.join("ttyrecs.db").exists():
            db.create("ttyrecs.db")

            # 1. Add "basictest" dataset.
            _add_basictest_dataset(root, paths)

            # 2. Add "altorgtest" dataset.
            populate_db.add_altorg_directory(
                test_converter.getfilename("altorg"), "altorgtest"
            )
            db.set_root("altorgtest", "/path/to/altorg/")

            # 3. Generate and add "nletest" dataset. The ttyrecs produced by
            # gen_ttyrecs("") are expected under <root>/nle_data, which is the
            # directory indexed here.
            gen_ttyrecs("")
            populate_db.add_nledata_directory(
                os.path.join(root, "nle_data"), "nletest"
            )

        yield newcwd
    finally:
        os.chdir(oldcwd)


@pytest.fixture(scope="session")
def conn(mockdata):
    """Import this fixture to receive a connection to the working mock db.

    Depends on ``mockdata`` so the database exists, and the cwd points at it,
    before ``db.db()`` is opened. The ``db.db()`` context is exited at
    session teardown.
    """
    connection_ctx = db.db()
    with connection_ctx as connection:
        yield connection


@pytest.mark.usefixtures("mockdata")
class TestDB:
    """Tests for ``nle.dataset.db`` against the session-wide mock database.

    ``mockdata`` is applied to every test. Tests that take no ``conn`` (e.g.
    test_vacuum, test_setroot) therefore still run with the mock db created
    and the cwd switched to it, even when selected in isolation.
    """

    def test_conn(self, conn):
        assert conn

    def test_ls(self, conn):
        rows = list(db.ls(conn))
        # 28 ttyrec rows across the three datasets built by `mockdata`.
        # Checked first so an empty listing fails cleanly instead of NameError.
        assert len(rows) == 28
        for expected_rowid, (rowid, _path) in enumerate(rows, start=1):
            assert int(rowid) == expected_rowid

    def test_vacuum(self):
        db.vacuum()

    def test_getrow(self):
        row1 = db.get_row("1")
        row5 = db.get_row("5")
        row9 = db.get_row("9")

        assert len(row9) == 6
        rowids, paths, _parts, sizes, *_ = zip(row1, row5, row9)
        assert rowids == (1, 5, 9)
        # Directory order is OS dependent, so only check membership rather
        # than exact paths.
        names = [user for user, _ in TTYRECS]
        filenames = [name + ".ttyrec.bz2" for name in names]
        for p in paths:
            assert os.path.dirname(p) in names
            assert os.path.basename(p) in filenames
        assert sizes == (100, 500, 900)

    def test_getmeta(self, conn):
        with conn:
            meta = db.get_meta(conn)
        ctime, mtime = meta

        assert 0 < mtime - ctime < 45

    def test_setroot(self):
        dataset_name = "basictest"
        oldroot = db.get_root(dataset_name)
        newroot = "/my/new/root"
        try:
            db.set_root(dataset_name, newroot)
            assert db.get_root(dataset_name) == newroot
        finally:
            # The db is shared by the whole session, so always restore the
            # original root even if the assertion above fails.
            db.set_root(dataset_name, oldroot)
        assert db.get_root(dataset_name) == oldroot

    def test_games(self, conn):
        user_games = {"aaa": [], "bbb": [], "ccc": []}
        for user, gameid, start, end in db.get_games("basictest"):
            user_games[user].append((gameid, start, end))

        assert len(user_games["aaa"]) == 3
        assert len(user_games["bbb"]) == 3
        assert len(user_games["ccc"]) == 1

        # NOTE(review): column 1 of `SELECT *` is presumably `part` — confirm
        # against the ttyrecs schema in db.create.
        cmd = "SELECT * FROM ttyrecs WHERE gameid = ? ORDER BY part"
        for gameid, _, _ in user_games["aaa"] + user_games["bbb"]:
            for ttyrec in conn.execute(cmd, (gameid,)).fetchall():
                assert ttyrec[1] == 0

        # "ccc"'s single game is split into consecutive parts 0, 1, 2, ...
        for gameid, _, _ in user_games["ccc"]:
            for i, ttyrec in enumerate(conn.execute(cmd, (gameid,)).fetchall()):
                assert ttyrec[1] == i

    def test_version(self):
        # Expect adding nledata to provide latest ttyrec_version
        ttyrec_version = db.get_ttyrec_version("nletest")
        assert ttyrec_version == nle.nethack.TTYREC_VERSION

        # Expect adding altorg data additions to give version 1
        ttyrec_version = db.get_ttyrec_version("altorgtest")
        assert ttyrec_version == 1

        # Expect basictest using old ttyrecs to give version 1
        ttyrec_version = db.get_ttyrec_version("basictest")
        assert ttyrec_version == 1
